import torch
import torch.nn as nn

def focal_loss(outputs, targets, alpha=0.25, gamma=2, reduction='mean'):
    """Compute the focal loss on top of softmax cross-entropy.

    Each element's cross-entropy ``CE`` is scaled by
    ``alpha * (1 - p_t) ** gamma``, where ``p_t = exp(-CE)``. As a result,
    confidently-correct examples (high ``p_t``) contribute less than hard ones.

    Args:
        outputs: Unnormalized logits, in any shape accepted by
            ``torch.nn.functional.cross_entropy``, e.g. ``(N, C)``.
        targets: Targets in any form accepted by
            ``torch.nn.functional.cross_entropy``, e.g. ``(N,)`` class indices.
        alpha: Constant scalar weight applied to every element.
        gamma: Focusing exponent. With ``gamma=0`` and ``alpha=1`` this reduces
            to plain cross-entropy.
        reduction: ``'mean'`` (default, the original behavior), ``'sum'``, or
            ``'none'`` to return the per-element losses.

    Returns:
        A scalar tensor for ``'mean'``/``'sum'``; otherwise the unreduced
        per-element loss tensor.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    if reduction not in ('mean', 'sum', 'none'):
        raise ValueError(
            f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}"
        )
    # Per-element multi-class cross-entropy (not binary CE), left unreduced so
    # the focal weighting can be applied element-wise.
    # NOTE(review): cross_entropy's default ignore_index=-100 gives ignored
    # targets a loss of 0, but 'mean' below still counts them in the
    # denominator. Confirm callers never pass -100 targets.
    ce_loss = nn.functional.cross_entropy(outputs, targets, reduction='none')
    # For class-index targets, exp(-CE) is the predicted probability of the
    # true class.
    pt = torch.exp(-ce_loss)
    # (1 - p_t) ** gamma shrinks the contribution of well-classified examples.
    focal = alpha * (1 - pt) ** gamma * ce_loss
    if reduction == 'mean':
        return focal.mean()
    if reduction == 'sum':
        return focal.sum()
    return focal
